package connect

import (
	"errors"
	"fmt"
	"golang.org/x/crypto/ssh"
	"net"
	"time"
)

// JumpServer holds the connection settings for an SSH bastion (jump host)
// and, once Open has succeeded, the SSH client used to dial hosts behind it.
type JumpServer struct {
	IP         string        // bastion host address
	SSHPort    int           // bastion SSH port
	Username   string        // SSH login user
	Password   string        // password credential, handed to sshClientConfig
	SSHKey     string        // private key credential (contents vs. path? — confirm in sshClientConfig)
	Passphrase string        // passphrase protecting the SSH key
	Timeout    time.Duration // timeout for connecting to the bastion (Open) and for dialing through it (dial)
	SSHClient  *ssh.Client   // assigned by Open; nil before that
}

// NewJumpServer returns a JumpServer populated with the given connection
// settings. It performs no network activity; call Open to connect.
func NewJumpServer(ip, username, password, sshKey, passphrase string, sshPort int, timeout time.Duration) *JumpServer {
	js := new(JumpServer)
	js.IP, js.SSHPort = ip, sshPort
	js.Username, js.Password = username, password
	js.SSHKey, js.Passphrase = sshKey, passphrase
	js.Timeout = timeout
	return js
}

// Open establishes the SSH connection to the jump server and stores the
// resulting client in j.SSHClient.
//
// Authentication methods come from sshClientConfig, built from j.Password,
// j.SSHKey and j.Passphrase. j.Timeout is passed as ssh.ClientConfig.Timeout,
// which bounds TCP connection establishment; zero means no limit.
//
// j.SSHClient is always overwritten with the result of ssh.Dial, which is nil
// when the dial fails. Returned errors are wrapped with context and can be
// inspected with errors.Is / errors.As.
func (j *JumpServer) Open() error {
	auth, err := sshClientConfig(j.Password, j.SSHKey, j.Passphrase)
	if err != nil {
		return fmt.Errorf("building jump server auth methods: %w", err)
	}
	config := &ssh.ClientConfig{
		User: j.Username,
		Auth: auth,
		// SECURITY NOTE(review): every host key is accepted, which leaves the
		// connection open to man-in-the-middle attacks. Kept for behavioural
		// compatibility; consider ssh.FixedHostKey or a known_hosts callback.
		HostKeyCallback: ssh.InsecureIgnoreHostKey(),
		Timeout:         j.Timeout,
	}
	// JoinHostPort brackets IPv6 literals ("[::1]:22"); "%s:%d" would not.
	addr := net.JoinHostPort(j.IP, fmt.Sprint(j.SSHPort))
	j.SSHClient, err = ssh.Dial("tcp", addr, config)
	if err != nil {
		return fmt.Errorf("ssh dial to jump server %s: %w", addr, err)
	}
	return nil
}

// dial opens a TCP connection to ip:port tunnelled through the jump server's
// SSH client.
//
// When j.Timeout is positive, the dial is abandoned after that long and a
// timeout error is returned. A connection that completes after the deadline
// is closed in the background so it does not leak. A zero or negative Timeout
// means no limit, matching ssh.ClientConfig.Timeout as used by Open.
//
// Open must have succeeded first; otherwise an error is returned instead of
// panicking on a nil client.
func (j *JumpServer) dial(ip string, port int) (net.Conn, error) {
	client := j.SSHClient
	if client == nil {
		return nil, errors.New("bastion is not connected; call Open first")
	}
	// JoinHostPort brackets IPv6 literals; "%s:%d" would produce an invalid address.
	addr := net.JoinHostPort(ip, fmt.Sprint(port))
	if j.Timeout <= 0 {
		return client.Dial("tcp", addr)
	}

	type dialResult struct {
		conn net.Conn
		err  error
	}
	// Buffered so the dialing goroutine can always deliver its result and exit,
	// even after we have stopped waiting. Results travel over the channel
	// rather than shared variables, so there is no data race on timeout.
	results := make(chan dialResult, 1)
	go func() {
		c, e := client.Dial("tcp", addr)
		results <- dialResult{conn: c, err: e}
	}()

	timer := time.NewTimer(j.Timeout)
	defer timer.Stop()
	select {
	case r := <-results:
		return r.conn, r.err
	case <-timer.C:
		// The dial is still in flight; close whatever connection it eventually
		// yields, because no caller will ever receive it.
		go func() {
			if r := <-results; r.conn != nil {
				_ = r.conn.Close() // best effort: nobody is left to report to
			}
		}()
		return nil, fmt.Errorf("bastion dial to %s:%d timeout", ip, port)
	}
}

// Close shuts down the SSH connection to the jump server, if one is open.
//
// It is a no-op when j.SSHClient is nil (Open was never called or failed),
// instead of panicking on a nil client.
func (j *JumpServer) Close() {
	if j.SSHClient == nil {
		return
	}
	// Close has no error return, so teardown is best effort and the error
	// from the underlying client is deliberately discarded.
	_ = j.SSHClient.Close()
}
